{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5 - Multi-class Sentiment Analysis\n",
    "\n",
    "In all of the previous notebooks we have performed sentiment analysis on a dataset with only two classes, positive or negative. When we have only two classes our output can be a single scalar, bound between 0 and 1, that indicates what class an example belongs to. When we have more than 2 examples, our output must be a $C$ dimensional vector, where $C$ is the number of classes.\n",
    "\n",
    "In this notebook, we'll be performing classification on a dataset with 6 classes. Note that this dataset isn't actually a sentiment analysis dataset, it's a dataset of questions and the task is to classify what category the question belongs to. However, everything covered in this notebook applies to any dataset with examples that contain an input sequence belonging to one of $C$ classes.\n",
    "\n",
    "Below, we setup the fields, and load the dataset. \n",
    "\n",
    "The first difference is that we do not need to set the `dtype` in the `LABEL` field. When doing a mutli-class problem, PyTorch expects the labels to be numericalized `LongTensor`s. \n",
    "\n",
    "The second different is that we use `TREC` instead of `IMDB` to load the `TREC` dataset. The `fine_grained` argument allows us to use the fine-grained labels (of which there are 50 classes) or not (in which case they'll be 6 classes). You can change this how you please."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "import torch\n",
    "from torchtext.legacy import data\n",
    "from torchtext.legacy import datasets\n",
    "import random\n",
    "\n",
    "SEED = 1234\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "TEXT = data.Field(tokenize = 'spacy',\n",
    "                  tokenizer_language = 'en_core_web_sm')\n",
    "\n",
    "LABEL = data.LabelField()\n",
    "\n",
    "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n",
    "\n",
    "train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at one of the examples in the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?'], 'label': 'DESC'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(train_data[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll build the vocabulary. As this dataset is small (only ~3800 training examples) it also has a very small vocabulary (~7500 unique tokens), this means we do not need to set a `max_size` on the vocabulary as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_VOCAB_SIZE = 25_000\n",
    "\n",
    "TEXT.build_vocab(train_data, \n",
    "                 max_size = MAX_VOCAB_SIZE, \n",
    "                 vectors = \"glove.6B.100d\", \n",
    "                 unk_init = torch.Tensor.normal_)\n",
    "\n",
    "LABEL.build_vocab(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can check the labels.\n",
    "\n",
    "The 6 labels (for the non-fine-grained case) correspond to the 6 types of questions in the dataset:\n",
    "- `HUM` for questions about humans\n",
    "- `ENTY` for questions about entities\n",
    "- `DESC` for questions asking you for a description \n",
    "- `NUM` for questions where the answer is numerical\n",
    "- `LOC` for questions where the answer is a location\n",
    "- `ABBR` for questions asking about abbreviations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(None, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
     ]
    }
   ],
   "source": [
    "print(LABEL.vocab.stoi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As always, we set up the iterators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size = BATCH_SIZE, \n",
    "    device = device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using the CNN model from the previous notebook, however any of the models covered in these tutorials will work on this dataset. The only difference is now the `output_dim` will be $C$ instead of $1$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        \n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv2d(in_channels = 1, \n",
    "                                              out_channels = n_filters, \n",
    "                                              kernel_size = (fs, embedding_dim)) \n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "        \n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        \n",
    "        #text = [sent len, batch size]\n",
    "        \n",
    "        text = text.permute(1, 0)\n",
    "                \n",
    "        #text = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "                \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.unsqueeze(1)\n",
    "        \n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "        \n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "            \n",
    "        #conv_n = [batch size, n_filters, sent len - filter_sizes[n]]\n",
    "        \n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "            \n",
    "        return self.fc(cat)"
   ]
  },
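  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shape comments in `forward` concrete, here is a minimal sketch that pushes a random tensor through the same convolution and pooling steps. The sizes (`batch_size`, `sent_len`, `emb_dim`, `n_filters`) are hypothetical values chosen only for this illustration, and the random input stands in for real embeddings.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical sizes, chosen only for this illustration.\n",
    "batch_size, sent_len, emb_dim, n_filters = 2, 7, 10, 5\n",
    "\n",
    "# A fake batch of embedded sentences: [batch size, 1, sent len, emb dim]\n",
    "embedded = torch.randn(batch_size, 1, sent_len, emb_dim)\n",
    "\n",
    "shapes = {}\n",
    "for fs in [2, 3, 4]:\n",
    "    conv = nn.Conv2d(in_channels = 1, out_channels = n_filters, kernel_size = (fs, emb_dim))\n",
    "    # conved = [batch size, n_filters, sent len - fs + 1]\n",
    "    conved = F.relu(conv(embedded)).squeeze(3)\n",
    "    # pooled = [batch size, n_filters]\n",
    "    pooled = F.max_pool1d(conved, conved.shape[2]).squeeze(2)\n",
    "    shapes[fs] = (tuple(conved.shape), tuple(pooled.shape))\n",
    "    print(fs, shapes[fs])\n",
    "```\n",
    "\n",
    "Whatever the filter size, max pooling over the whole sequence leaves one value per filter, which is why the pooled outputs can be concatenated into a fixed-size vector for the linear layer."
   ]
  },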
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define our model, making sure to set `OUTPUT_DIM` to $C$. We can get $C$ easily by using the size of the `LABEL` vocab, much like we used the length of the `TEXT` vocab to get the size of the vocabulary of the input.\n",
    "\n",
    "The examples in this dataset are generally a lot smaller than those in the IMDb dataset, so we'll use smaller filter sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "EMBEDDING_DIM = 100\n",
    "N_FILTERS = 100\n",
    "FILTER_SIZES = [2,3,4]\n",
    "OUTPUT_DIM = len(LABEL.vocab)\n",
    "DROPOUT = 0.5\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Checking the number of parameters, we can see how the smaller filter sizes means we have about a third of the parameters than we did for the CNN model on the IMDb dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 841,806 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
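  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed total can be reproduced by hand, which also shows where the parameters live. This is a rough sketch: the vocabulary size of 7,497 is an assumption inferred from the total above (this notebook never prints `len(TEXT.vocab)`), and the remaining values are the hyperparameters defined earlier.\n",
    "\n",
    "```python\n",
    "# Assumed vocabulary size, inferred from the printed total rather than read from TEXT.vocab.\n",
    "vocab_size = 7_497\n",
    "embedding_dim = 100\n",
    "n_filters = 100\n",
    "filter_sizes = [2, 3, 4]\n",
    "output_dim = 6\n",
    "\n",
    "# One embedding_dim-sized vector per token in the vocabulary.\n",
    "embedding = vocab_size * embedding_dim\n",
    "# Each filter has fs * embedding_dim weights plus one bias.\n",
    "convs = sum(n_filters * (fs * embedding_dim + 1) for fs in filter_sizes)\n",
    "# The linear layer has one weight per (input, output) pair plus one bias per output.\n",
    "fc = len(filter_sizes) * n_filters * output_dim + output_dim\n",
    "\n",
    "total = embedding + convs + fc\n",
    "print(f'{embedding:,} + {convs:,} + {fc:,} = {total:,}')\n",
    "```\n",
    "\n",
    "Under these assumptions the embedding layer accounts for the large majority of the parameters."
   ]
  },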
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll load our pre-trained embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1117, -0.4966,  0.1631,  ...,  1.2647, -0.2753, -0.1325],\n",
       "        [-0.8555, -0.7208,  1.3755,  ...,  0.0825, -1.1314,  0.3997],\n",
       "        [ 0.1638,  0.6046,  1.0789,  ..., -0.3140,  0.1844,  0.3624],\n",
       "        ...,\n",
       "        [-0.3110, -0.3398,  1.0308,  ...,  0.5317,  0.2836, -0.0640],\n",
       "        [ 0.0091,  0.2810,  0.7356,  ..., -0.7508,  0.8967, -0.7631],\n",
       "        [ 0.5831, -0.2514,  0.4156,  ..., -0.2735, -0.8659, -1.4063]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_embeddings = TEXT.vocab.vectors\n",
    "\n",
    "model.embedding.weight.data.copy_(pretrained_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then zero the initial weights of the unknown and padding tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
    "\n",
    "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
    "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another different to the previous notebooks is our loss function (aka criterion). Before we used `BCEWithLogitsLoss`, however now we use `CrossEntropyLoss`. Without going into too much detail, `CrossEntropyLoss` performs a *softmax* function over our model outputs and the loss is given by the *cross entropy* between that and the label.\n",
    "\n",
    "Generally:\n",
    "- `CrossEntropyLoss` is used when our examples exclusively belong to one of $C$ classes\n",
    "- `BCEWithLogitsLoss` is used when our examples exclusively belong to only 2 classes (0 and 1) and is also used in the case where our examples belong to between 0 and $C$ classes (aka multilabel classification)."
   ]
  },
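  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the softmax and cross entropy steps described above, here is a sketch for a single example, written with only the standard library so the arithmetic is visible. The helper names `softmax` and `cross_entropy` and the logits are hypothetical and are not part of PyTorch; `nn.CrossEntropyLoss` does the equivalent computation on tensors and, by default, averages the result over the batch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtracting the max keeps exp() numerically stable without changing the result.\n",
    "    m = max(logits)\n",
    "    exps = [math.exp(x - m) for x in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def cross_entropy(logits, label):\n",
    "    # Negative log of the probability assigned to the correct class.\n",
    "    return -math.log(softmax(logits)[label])\n",
    "\n",
    "logits = [2.0, 0.5, -1.0]  # hypothetical raw model outputs for C = 3 classes\n",
    "probs = softmax(logits)\n",
    "\n",
    "print(probs)\n",
    "print(cross_entropy(logits, 0))  # small loss: class 0 has the largest logit\n",
    "print(cross_entropy(logits, 2))  # large loss: class 2 has the smallest logit\n",
    "```"
   ]
  },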
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before, we had a function that calculated accuracy in the binary label case, where we said if the value was over 0.5 then we would assume it is positive. In the case where we have more than 2 classes, our model outputs a $C$ dimensional vector, where the value of each element is the beleief that the example belongs to that class. \n",
    "\n",
    "For example, in our labels we have: 'HUM' = 0, 'ENTY' = 1, 'DESC' = 2, 'NUM' = 3, 'LOC' = 4 and 'ABBR' = 5. If the output of our model was something like: **[5.1, 0.3, 0.1, 2.1, 0.2, 0.6]** this means that the model strongly believes the example belongs to class 0, a question about a human, and slightly believes the example belongs to class 3, a numerical question.\n",
    "\n",
    "We calculate the accuracy by performing an `argmax` to get the index of the maximum value in the prediction for each element in the batch, and then counting how many times this equals the actual label. We then average this across the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categorical_accuracy(preds, y):\n",
    "    \"\"\"\n",
    "    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
    "    \"\"\"\n",
    "    top_pred = preds.argmax(1, keepdim = True)\n",
    "    correct = top_pred.eq(y.view_as(top_pred)).sum()\n",
    "    acc = correct.float() / y.shape[0]\n",
    "    return acc"
   ]
  },
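  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to see the accuracy calculation at work is to run it on a tiny hand-made batch. The sketch below repeats the function so that it runs on its own; the predictions and labels are made up for illustration, with the first row taken from the example output discussed earlier.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def categorical_accuracy(preds, y):\n",
    "    top_pred = preds.argmax(1, keepdim = True)\n",
    "    correct = top_pred.eq(y.view_as(top_pred)).sum()\n",
    "    return correct.float() / y.shape[0]\n",
    "\n",
    "# Made-up batch of 3 examples over 6 classes: [batch size, n classes]\n",
    "preds = torch.tensor([[5.1, 0.3, 0.1, 2.1, 0.2, 0.6],\n",
    "                      [0.2, 0.1, 3.0, 0.4, 0.3, 0.1],\n",
    "                      [0.5, 2.2, 0.1, 0.3, 0.9, 0.4]])\n",
    "\n",
    "# Made-up labels: [batch size]. The third example is misclassified (its argmax is 1).\n",
    "y = torch.tensor([0, 2, 4])\n",
    "\n",
    "acc = categorical_accuracy(preds, y)\n",
    "\n",
    "print(preds.argmax(1))\n",
    "print(acc.item())\n",
    "```\n",
    "\n",
    "Two of the three predictions match their labels, so the batch accuracy is two thirds."
   ]
  },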
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop is similar to before, without the need to `squeeze` the model predictions as `CrossEntropyLoss` expects the input to be **[batch size, n classes]** and the label to be **[batch size]**.\n",
    "\n",
    "The label needs to be a `LongTensor`, which it is by default as we did not set the `dtype` to a `FloatTensor` as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch in iterator:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        predictions = model(batch.text)\n",
    "        \n",
    "        loss = criterion(predictions, batch.label)\n",
    "        \n",
    "        acc = categorical_accuracy(predictions, batch.label)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation loop is, again, similar to before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for batch in iterator:\n",
    "\n",
    "            predictions = model(batch.text)\n",
    "            \n",
    "            loss = criterion(predictions, batch.label)\n",
    "            \n",
    "            acc = categorical_accuracy(predictions, batch.label)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we train our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 1.312 | Train Acc: 47.11%\n",
      "\t Val. Loss: 0.947 |  Val. Acc: 66.41%\n",
      "Epoch: 02 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.870 | Train Acc: 69.18%\n",
      "\t Val. Loss: 0.741 |  Val. Acc: 74.14%\n",
      "Epoch: 03 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.675 | Train Acc: 76.32%\n",
      "\t Val. Loss: 0.621 |  Val. Acc: 78.49%\n",
      "Epoch: 04 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.506 | Train Acc: 83.97%\n",
      "\t Val. Loss: 0.547 |  Val. Acc: 80.32%\n",
      "Epoch: 05 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.373 | Train Acc: 88.23%\n",
      "\t Val. Loss: 0.487 |  Val. Acc: 82.92%\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 5\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    start_time = time.time()\n",
    "    \n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
    "    \n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut5-model.pt')\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's run our model on the test set!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.415 | Test Acc: 86.07%\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('tut5-model.pt'))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar to how we made a function to predict sentiment for any given sentences, we can now make a function that will predict the class of question given.\n",
    "\n",
    "The only difference here is that instead of using a sigmoid function to squash the input between 0 and 1, we use the `argmax` to get the highest predicted class index. We then use this index with the label vocab to get the human readable label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "def predict_class(model, sentence, min_len = 4):\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    if len(tokenized) < min_len:\n",
    "        tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    tensor = torch.LongTensor(indexed).to(device)\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    preds = model(tensor)\n",
    "    max_preds = preds.argmax(dim = 1)\n",
    "    return max_preds.item()"
   ]
  },
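  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `min_len = 4` default matches the largest filter size used above: a convolution whose kernel covers 4 tokens cannot be applied to a shorter input, so short questions are padded first. Here is a small sketch of just that padding step; `pad_tokens` is a hypothetical helper written for this illustration, not part of torchtext.\n",
    "\n",
    "```python\n",
    "def pad_tokens(tokens, min_len = 4, pad_token = '<pad>'):\n",
    "    # Return a new list with pad tokens appended until it holds at least min_len items.\n",
    "    if len(tokens) < min_len:\n",
    "        return tokens + [pad_token] * (min_len - len(tokens))\n",
    "    return tokens\n",
    "\n",
    "short = pad_tokens(['Why', '?'])\n",
    "long = pad_tokens(['What', 'is', 'a', 'Cartesian', 'Diver', '?'])\n",
    "\n",
    "print(short)\n",
    "print(long)\n",
    "```\n",
    "\n",
    "Questions that already have at least `min_len` tokens pass through unchanged."
   ]
  },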
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's try it out on a few different questions..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 0 = HUM\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"Who is Keyser Söze?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 3 = NUM\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 4 = LOC\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 5 = ABBR\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}